#!/usr/bin/env python

import os
import json
import pickle
import sys
import traceback
import pandas as pd
import numpy as np
import random
import datetime
from pathlib import Path
import logging
import torch
import shutil
from transformers import AutoTokenizer

from fast_bert.data_cls import BertDataBunch
from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import (
    accuracy,
    accuracy_multilabel,
    accuracy_thresh,
    fbeta,
    roc_auc,
)

# Timestamp taken once at import time, formatted as YYYY-MM-DD_HH-MM-SS.
run_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# This algorithm reads a single input data channel with this name.
channel_name = "training"

# Root directory under which every path used by this script lives.
prefix = "/opt/ml/"

# --- code and pretrained models ------------------------------------------
code_path = "{}code".format(prefix)  # /opt/ml/code
pretrained_model_path = "{}/pretrained_models".format(
    code_path
)  # /opt/ml/code/pretrained_models

# --- input data -------------------------------------------------------------
# Since we run in File mode, the channel's input files are copied to
# `training_path` before this script starts.
input_path = "{}input/data".format(prefix)  # /opt/ml/input/data
training_path = os.path.join(input_path, channel_name)  # /opt/ml/input/data/training
finetuned_path = "{}/{}/finetuned".format(
    input_path, channel_name
)  # /opt/ml/input/data/training/finetuned

# --- configuration files ----------------------------------------------------
training_config_path = os.path.join(
    input_path, channel_name + "/config"
)  # /opt/ml/input/data/training/config
config_path = os.path.join(
    training_config_path, "training_config.json"
)  # /opt/ml/input/data/training/config/training_config.json
hyperparam_path = os.path.join(
    prefix, "input/config/hyperparameters.json"
)  # /opt/ml/input/config/hyperparameters.json

# --- outputs ------------------------------------------------------------------
output_path = os.path.join(prefix, "output")  # /opt/ml/output
model_path = os.path.join(prefix, "model")  # /opt/ml/model


def searching_all_files(directory: Path) -> list:
    """Return the paths of all regular files below *directory*, recursively.

    Args:
        directory: Directory to walk.

    Returns:
        A flat list of file paths as strings, in the order the directory
        entries are yielded by ``Path.iterdir`` (which is arbitrary).
        Sub-directories contribute their files to the same flat list; an
        empty directory yields an empty list.
    """
    file_list = []

    for entry in directory.iterdir():
        if entry.is_file():
            file_list.append(str(entry))
        elif entry.is_dir():
            # extend (not append) keeps the result flat instead of nesting one
            # list per sub-directory.
            file_list.extend(searching_all_files(entry))
        # Anything else (e.g. a broken symlink) is neither a file nor a
        # directory and cannot be iterated, so it is skipped.

    return file_list


def _parse_bool(value):
    """Interpret a config flag given either as a real bool or as text like "True"."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == "true"


# The function to execute the training.
def train():
    """Fine-tune a text classifier with fast-bert and write the artefacts.

    Reads the training configuration from ``config_path`` and the
    hyperparameters from ``hyperparam_path``, builds a ``BertDataBunch`` from
    the files in ``training_path``, trains and validates a ``BertLearner``,
    optionally scores a test file, and saves the model, the effective
    configuration, the label list and ``labels_metadata.json`` to
    ``model_path``.

    Returns:
        None.

    Raises:
        SystemExit: with code 255 if any exception occurs. The exception and
            its traceback are first written to ``<output_path>/failure`` and
            to stderr.
    """
    print("Starting the training.")

    DATA_PATH = Path(training_path)
    LABEL_PATH = Path(training_path)

    try:
        print(config_path)
        with open(config_path, "r") as config_file:
            training_config = json.load(config_file)
            print(training_config)

        with open(hyperparam_path, "r") as hyperparam_file:
            hyperparameters = json.load(hyperparam_file)
            print(hyperparameters)

        # Flags may arrive as the strings "True"/"False" or as JSON booleans;
        # normalise both to real booleans.
        training_config["multi_label"] = _parse_bool(training_config["multi_label"])
        training_config["fp16"] = _parse_bool(training_config["fp16"])

        # Fill in defaults for optional settings.
        training_config.setdefault("text_col", "text")
        training_config.setdefault("label_col", "label")
        training_config.setdefault("train_file", "train.csv")
        training_config.setdefault("val_file", "val.csv")
        training_config.setdefault("label_file", "labels.csv")
        training_config.setdefault("random_state", None)

        if training_config["random_state"] is not None:
            print("setting random state {}".format(training_config["random_state"]))
            random_seed(int(training_config["random_state"]))

        # Logger: everything goes to stdout.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            handlers=[logging.StreamHandler(sys.stdout)],
        )

        logger = logging.getLogger()

        # Define pretrained model path. Preloaded models are stored in a
        # directory named after the model with "/" replaced by ":".
        PRETRAINED_PATH = Path(pretrained_model_path) / training_config[
            "model_name"
        ].replace("/", ":")
        if PRETRAINED_PATH.is_dir():
            logger.info("model path used {}".format(PRETRAINED_PATH))
            model_name_path = str(PRETRAINED_PATH)
        else:
            model_name_path = training_config["model_name"]
            logger.info(
                "model {} is not preloaded. Will try to download.".format(
                    model_name_path
                )
            )

        finetuned_model_name = training_config.get("finetuned_model", None)
        if finetuned_model_name is not None:
            finetuned_model = os.path.join(finetuned_path, finetuned_model_name)
            logger.info("finetuned model loaded from {}".format(finetuned_model))
        else:
            logger.info(
                "finetuned model not available - loading standard pretrained model"
            )
            finetuned_model = None

        # use auto-tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name_path, use_fast=True)

        # Fall back to the CPU when no GPU is visible instead of requesting a
        # CUDA device that does not exist.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        multi_gpu = torch.cuda.device_count() > 1

        logger.info("Number of GPUs: {}".format(torch.cuda.device_count()))

        text_col = training_config["text_col"]
        label_col = training_config["label_col"]
        if training_config["multi_label"] is True and isinstance(label_col, str):
            # Multi-label column names are given as a JSON-encoded list; a
            # value that is already a list is used as-is.
            label_col = json.loads(label_col)

        logger.info("label columns: {}".format(label_col))

        test_data = None
        test_df = None
        if training_config.get("test_file", None):
            test_file_path = DATA_PATH / training_config["test_file"]
            try:
                test_df = pd.read_csv(test_file_path)
            except UnicodeDecodeError:
                # Not valid UTF-8: retry with a single-byte encoding.
                test_df = pd.read_csv(test_file_path, encoding="latin1")
            # Use the configured text column, the same one the databunch reads.
            test_data = list(test_df[text_col])
            logger.info("Test file available. Test count {}".format(len(test_df)))

        # Create databunch
        databunch = BertDataBunch(
            DATA_PATH,
            LABEL_PATH,
            tokenizer,
            train_file=training_config["train_file"],
            val_file=training_config["val_file"],
            label_file=training_config["label_file"],
            text_col=text_col,
            test_data=test_data,
            label_col=label_col,
            batch_size_per_gpu=int(hyperparameters["train_batch_size"]),
            max_seq_length=int(hyperparameters["max_seq_length"]),
            multi_gpu=multi_gpu,
            multi_label=training_config["multi_label"],
            model_type=training_config["model_type"],
            logger=logger,
            no_cache=True,
        )

        if training_config["multi_label"] is False:
            metrics = [{"name": "accuracy", "function": accuracy}]
        else:
            metrics = [
                {"name": "accuracy_thresh", "function": accuracy_thresh},
                {"name": "roc_auc", "function": roc_auc},
                {"name": "fbeta", "function": fbeta},
            ]

        logger.info("databunch labels: {}".format(len(databunch.labels)))

        # Initialise the learner
        learner = BertLearner.from_pretrained_model(
            databunch,
            model_name_path,
            metrics=metrics,
            device=device,
            logger=logger,
            output_dir=Path(model_path),
            finetuned_wgts_path=finetuned_model,
            is_fp16=training_config["fp16"],
            fp16_opt_level=training_config["fp16_opt_level"],
            warmup_steps=int(hyperparameters["warmup_steps"]),
            grad_accumulation_steps=int(training_config["grad_accumulation_steps"]),
            multi_gpu=multi_gpu,
            multi_label=training_config["multi_label"],
            logging_steps=int(training_config["logging_steps"]),
        )

        learner.fit(
            int(hyperparameters["epochs"]),
            float(hyperparameters["lr"]),
            schedule_type=hyperparameters["lr_schedule"],
            optimizer_type=hyperparameters["optimizer_type"],
        )

        validation = learner.validate(return_preds=True)
        logger.info("y_pred: {}".format(json.dumps(validation["y_preds"].tolist())))
        logger.info("y_true: {}".format(json.dumps(validation["y_true"].tolist())))
        logger.info("labels: {}".format(json.dumps(databunch.labels)))

        if test_data is not None:
            predictions = learner.predict_batch()
            # A single string label column doubles as the ground-truth column;
            # otherwise fall back to the conventional "label" column.
            truth_col = label_col if isinstance(label_col, str) else "label"

            test_results = []
            # Pair rows with predictions by position, so the pairing does not
            # depend on the DataFrame's index labels.
            for position, (_, row) in enumerate(test_df.iterrows()):
                result = {
                    "text": row.get(text_col),
                    "ground_truth": row.get(truth_col),
                }
                # Keep the first three (label, confidence) pairs for each row.
                for rank, pred in enumerate(predictions[position][:3], start=1):
                    result["label_{}".format(rank)] = pred[0]
                    result["confidence_{}".format(rank)] = pred[1]

                test_results.append(result)

            # save test results with the model artefacts and in the output dir
            test_results_df = pd.DataFrame(test_results)
            for target_dir in (model_path, output_path):
                test_results_df.to_csv(
                    os.path.join(target_dir, "test_result.csv"), index=False
                )

        # save model and tokenizer artefacts
        learner.save_model()

        # save model config file
        with open(os.path.join(model_path, "model_config.json"), "w") as f:
            json.dump(training_config, f)

        # save label file
        with open(os.path.join(model_path, "labels.csv"), "w") as f:
            f.write("\n".join(databunch.labels))

        # copy labels_metadata.json from LABEL_PATH to model_path
        # NOTE(review): this raises (and fails the job) when the file is
        # missing - confirm that labels_metadata.json is a required input.
        shutil.copyfile(
            os.path.join(LABEL_PATH, "labels_metadata.json"),
            os.path.join(model_path, "labels_metadata.json"),
        )

    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(os.path.join(output_path, "failure"), "w") as s:
            s.write("Exception during training: " + str(e) + "\n" + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print("Exception during training: " + str(e) + "\n" + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


def random_seed(seed_value):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well, and cuDNN
    is switched to deterministic mode with benchmarking turned off.

    Args:
        seed_value: Seed passed unchanged to every generator.
    """
    # Python, NumPy and PyTorch (CPU) generators, in that order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    # Current GPU first, then every visible GPU.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: run the training job. train() itself exits with a
    # non-zero code on failure, so reaching the next line means it succeeded.
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    raise SystemExit(0)
